# pip install seaborn pandas 
import seaborn as sns
sns.set_theme(style="white")
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick


def figure2(data, title, palette, path="output/figure2.png"):
    """Plot sentence length by decision-making process, split by group.

    Draws a seaborn point plot (``catplot(kind="point")``) of ``sentence``
    against ``ai``, with one line per ``treat`` level and 95% confidence-
    interval error bars. Axis ticks and legend entries get readable labels,
    and the figure is saved to ``path`` at 300 dpi.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns ``ai``, ``sentence`` and ``treat``. The
        relabelling below assumes two levels each, in the order
        ``ai`` = (direct, AI-assisted) and ``treat`` = (control, treatment).
        NOTE(review): confirm against the coding in the source data.
    title : str
        Figure-level title (``suptitle``).
    palette : sequence of colour specs
        One colour per ``treat`` level.
    path : str or os.PathLike, optional
        Output image path. Missing parent directories are created.

    Returns
    -------
    seaborn.FacetGrid
        The grid, so callers can customise or close it further.
    """
    from pathlib import Path

    g = sns.catplot(
        data=data, kind="point",
        x="ai", y="sentence", hue='treat',
        errorbar=('ci', 95), capsize=.05, palette=palette,
        linestyles=["-", "--"],
        linewidth=1.5, markers=['x', 'o'],
        alpha=1, height=4.5, aspect=1.5, dodge=True, legend=True
    )

    # Set x-axis and y-axis labels
    g.set_axis_labels("Decision-making Process", "Sentence (Month)")

    # Replace the raw ``ai`` category values on the x-axis with readable names.
    g.set_xticklabels(["Direct Decision", "AI-assisted Decision-making"])

    # Rename the legend entries through the public ``legend`` property.
    # ``move_legend`` rebuilds the legend from these (already renamed) texts.
    new_legend_labels = ['Control', 'Treatment (Trafficking)']
    for text, label in zip(g.legend.texts, new_legend_labels):
        text.set_text(label)
    sns.move_legend(g, "right", bbox_to_anchor=(1, .5), title='Groups')

    # Set title (``figure`` replaces the deprecated ``fig`` alias).
    g.figure.suptitle(title)

    # Save this grid's own figure rather than pyplot's implicit "current"
    # figure, and make sure the output directory exists first.
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    g.figure.savefig(out, dpi=300)
    return g

def figure3(data, palette, path="output/figure3.png"):
    """Plot conviction rate and sentence length by group in two stacked panels.

    Panel (a) shows ``conviction`` and panel (b) shows ``sentence`` against
    ``treat``. Each panel has one line per ``ai`` level and 95% confidence-
    interval error bars. The figure is saved to ``path`` at 300 dpi with a
    tight bounding box.

    Side effect: resets matplotlib's global style and applies a seaborn
    "ticks" theme (Times New Roman, "notebook" context, font_scale=1.5).
    These rcParams persist for any figure created afterwards.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns ``treat``, ``ai``, ``conviction`` and
        ``sentence``. The percent formatter on panel (a) treats 1.0 as 100%,
        which presumes ``conviction`` is a 0/1 indicator (NOTE(review):
        confirm). The tick/legend relabelling assumes two levels each, in the
        order ``treat`` = (control, treatment) and ``ai`` = (direct,
        AI-assisted).
    palette : sequence of colour specs
        One colour per ``ai`` level.
    path : str or os.PathLike, optional
        Output image path. Missing parent directories are created.

    Returns
    -------
    matplotlib.figure.Figure
        The two-panel figure.
    """
    from pathlib import Path

    # Reset all styles, then apply this figure's seaborn theme and scale.
    plt.style.use('default')
    sns.set_theme(style="ticks", font="Times New Roman")
    sns.set_context("notebook", font_scale=1.5)  # font_scale can be adjusted, e.g. 1.5, 2.0

    # Create figure with 2 subplots stacked vertically
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(9, 10))

    # Remove right and top borders of this figure's axes (explicit ``fig``
    # instead of relying on pyplot's current figure).
    sns.despine(fig=fig, right=True, top=True)

    # Both panels share identical point-plot styling; only ``y`` differs.
    point_kws = dict(
        data=data,
        x="treat", hue='ai',
        errorbar=("ci", 95), capsize=.05, palette=palette,
        linestyles=["-", "--"],
        linewidth=2, markers=['x', 'o'],
        alpha=1, dodge=0.05,
    )
    sns.pointplot(ax=ax1, y="conviction", **point_kws)  # panel (a)
    sns.pointplot(ax=ax2, y="sentence", **point_kws)    # panel (b)

    # Configure both subplots
    for ax in (ax1, ax2):
        ax.set_xlabel("Groups")
        ax.set_xticklabels(["Control", "Treatment (trafficking)"])

    # Reuse the handles of ax1's auto-generated legend for both panels, with
    # readable labels. The legend's upper-left corner is anchored at axes
    # coordinates (0, 1.1).
    handles = ax1.get_legend().legend_handles
    new_legend_labels = ['Direct Decision', 'AI-assisted Decision-making']
    for ax in (ax1, ax2):
        ax.legend(handles=handles,
                  labels=new_legend_labels,
                  loc='upper left',
                  bbox_to_anchor=(0.0, 1.1))

    # Set specific y-axis labels; show panel (a) as a percentage (1.0 == 100%).
    ax1.set_ylabel("Conviction Rate")
    ax2.set_ylabel("Sentence (months)")
    ax1.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

    # Add panel labels, positioned in axes coordinates
    ax1.text(-0.1, 1.1, 'Panel (a)', transform=ax1.transAxes)
    ax2.text(-0.1, 1.1, 'Panel (b)', transform=ax2.transAxes)

    fig.tight_layout()

    # Save this figure explicitly, creating the output directory if needed.
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=300, bbox_inches='tight')
    return fig
    
# Two-colour schemes: one colour per hue level of the respective figure.
gray_palette = ["#4D4D4D", "#999999"]
pastel_palette = ["#237B9F", "#AD0B08"]

# Load the analysis dataset (Stata .dta format), read relative to the
# current working directory.
df = pd.read_stata("data.dta")

figure2(data=df, title="Overall", palette=gray_palette)
figure3(data=df, palette=pastel_palette)